#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 11:21:56 2020

@author: agyeman
"""

from casadi import *
import numpy as np

from Parameters_1D import *

# Spatial discretisation of the 1-D domain, supplied by the project-local
# Parameters_1D module.
(totalDepth, axialNodes, axialDistance, totalNodes,
 numOfActuators, interPoints) = spatialVariables_1D()

# Square selection matrix whose only non-zero entry is the bottom-right
# diagonal element, i.e. it keeps just the last node.
selectMatrix = np.zeros((totalNodes, totalNodes))
selectMatrix[-1, -1] = 1.0

def inputConstraint(i, ulb,uub):
    """Return the control-input bounds to use at shooting interval ``i``.

    Only the first interval (``i == 0``) keeps the supplied bounds. Every later
    interval gets zero lower and upper bounds, which pins that control to 0.

    Args:
        i: zero-based interval index.
        ulb: lower bound(s) for the control vector (sequence or scalar).
        uub: upper bound(s) for the control vector (sequence or scalar).

    Returns:
        (ULB, UUB): ``ulb`` and ``uub`` unchanged when ``i == 0``. Otherwise,
        lists of zeros with one entry per element of ``ulb`` / ``uub``
        (a single entry for a scalar).
    """
    if i == 0:
        return ulb, uub
    # Size the zero bounds from the inputs so they match the length of the
    # control vector; a hard-coded [0] only fits a single actuator.
    return [0] * int(np.size(ulb)), [0] * int(np.size(uub))
    


def ZoneMPC_multiShooting(currentState,referenceTraj,refInput, Integrator, Q_matrix,
                          R_matrix,QUpper,QLower, UBoundState, LBoundState,
                          UBoundCtrl, LBoundCtrl,lowerZone, upperZone,LBSlack, UBSlack,cropCoeff, refEvap):
    """Assemble a multiple-shooting zone-tracking MPC problem from CasADi symbols.

    Decision-vector layout (blocks appended to ``w`` in this order)::

        X0 | U0 X1 el1 eu1 | U1 X2 el2 eu2 | ... (interPoints steps)

    X0 is pinned to ``currentState`` through equal lower/upper bounds. Each
    step adds a control block (bounded by LBoundCtrl/UBoundCtrl), a state
    block (LBoundState/UBoundState) and two slack blocks (LBSlack/UBSlack).

    Constraints, per step k:
      * continuity: Integrator(x0=X(k-1), p=[U(k-1); cropCoeff; refEvap])['xf'] - Xk == 0
      * soft zone:  lowerZone <= Xk + elk - euk <= upperZone

    Cost: sum over k of U(k-1)' R U(k-1) + elk' QLower elk + euk' QUpper euk.

    ``Q_matrix`` is accepted but not used by the current cost (it appears
    only in commented-out tracking terms). ``Integrator`` is presumably a
    CasADi integrator Function; confirm against the caller.

    Returns:
        w, lbw, ubw, G, J, Guess, lbg, ubg: decision-variable blocks and their
        bounds, constraint expressions, objective, initial guess, and
        constraint bounds.
    """
    # Initial state: equal lower/upper bounds fix it to the measurement.
    pinned = list(currentState)
    stateVar = MX.sym('X0', totalNodes)
    w = [stateVar]
    lbw = list(pinned)
    ubw = list(pinned)
    Guess = list(referenceTraj)

    G = []
    lbg = []
    ubg = []
    J = 0

    nodeZeros = list(np.zeros(totalNodes))

    for step in range(1, interPoints + 1):
        # Control applied over this interval.
        ctrlVar = MX.sym('U' + str(step - 1), numOfActuators)
        w.append(ctrlVar)
        lbw.extend(LBoundCtrl)
        ubw.extend(UBoundCtrl)
        Guess.extend(refInput)

        # Propagate the state at the start of the interval.
        shooting = Integrator(x0=stateVar, p=vertcat(ctrlVar, cropCoeff, refEvap))

        # Input-effort penalty.
        J += mtimes(ctrlVar.T, mtimes(R_matrix, ctrlVar))

        # Fresh symbols for the state at the end of the interval and its slacks.
        stateVar = MX.sym('X' + str(step), totalNodes)
        lowSlack = MX.sym('el' + str(step), totalNodes)
        highSlack = MX.sym('eu' + str(step), totalNodes)

        w.append(stateVar)
        lbw.extend(LBoundState)
        ubw.extend(UBoundState)
        Guess.extend(referenceTraj)

        for slack in (lowSlack, highSlack):
            w.append(slack)
            lbw.extend(LBSlack)
            ubw.extend(UBSlack)
            Guess.extend(nodeZeros)

        # Continuity (defect) constraint: integrated state must equal the new node.
        G.append(shooting['xf'] - stateVar)
        lbg.extend(nodeZeros)
        ubg.extend(nodeZeros)

        # Soft zone constraint; the slacks absorb violations of either bound.
        G.append(stateVar + lowSlack - highSlack)
        lbg.extend(lowerZone)
        ubg.extend(upperZone)

        J += (mtimes(lowSlack.T, mtimes(QLower, lowSlack))
              + mtimes(highSlack.T, mtimes(QUpper, highSlack)))

    return w, lbw, ubw, G, J, Guess, lbg, ubg



# def controlAndOL_response(x):
#     x_new=x.full().ravel()
    
#     #assert x_new.shape[0] == totalNodes*(interPoints+1) + numOfActuators*interPoints
    
    
#     sp_1 = (totalNodes + numOfActuators)-1
#     sp_2 = numOfActuators*interPoints
#     sp_3 = totalNodes + numOfActuators
    
#     u=[]
    
#     for j in range(sp_2):
#         index = sp_1 + j*sp_3
#         u.append(x_new[index])
    
    
        
#     x_OL=np.zeros((interPoints+1, totalNodes))
    
#     for k in range(interPoints+1):
#         x_OL[k,:]=x_new[k*sp_3:k*sp_3 + totalNodes]
     
#     return x_OL, np.array(u)




def ctrlOLResponse_zone(x, nodes=None, actuators=None):
    """Unpack the first shooting step from an optimal decision vector.

    The decision vector built by ``ZoneMPC_multiShooting`` starts with::

        X0 (nodes) | U0 (actuators) | X1 (nodes) | el1 (nodes) | eu1 (nodes) | ...

    Only this leading segment is read; later steps are ignored.

    Args:
        x: optimal decision vector. This is either an object with a ``full()``
            method (e.g. a CasADi ``DM``) or any array-like.
        nodes: number of state nodes; defaults to the module-level ``totalNodes``.
        actuators: number of control inputs; defaults to ``numOfActuators``.

    Returns:
        stateOL: (2, nodes) array holding X0 and X1.
        u: (actuators,) array holding every component of U0.
        slackLW: (1, nodes) array holding the lower-zone slack el1.
        slackUP: (1, nodes) array holding the upper-zone slack eu1.

    Raises:
        ValueError: if ``x`` is shorter than one full shooting step.
    """
    n = totalNodes if nodes is None else int(nodes)
    m = numOfActuators if actuators is None else int(actuators)

    raw = x.full() if hasattr(x, 'full') else x
    values = np.asarray(raw, dtype=float).ravel()

    # Offsets of each segment inside the first shooting step.
    u0Start = n
    x1Start = n + m
    elStart = 2 * n + m
    euStart = 3 * n + m
    stepEnd = 4 * n + m
    if values.size < stepEnd:
        raise ValueError(
            "decision vector has %d entries; expected at least %d"
            % (values.size, stepEnd))

    stateOL = np.vstack((values[0:n], values[x1Start:x1Start + n]))
    # Return every actuator of U0, not just the last component.
    u = values[u0Start:u0Start + m].copy()
    slackLW = values[elStart:elStart + n].reshape(1, n).copy()
    slackUP = values[euStart:euStart + n].reshape(1, n).copy()
    return stateOL, u, slackLW, slackUP

    
    
def createAndSolveNLP(initialGuess, w, lbw, ubw, G, J, lbg, ubg):
    """Build an IPOPT solver for the given NLP, solve it, and unpack the first step.

    Args:
        initialGuess: starting point for the decision vector.
        w: list of CasADi decision-variable blocks (concatenated with vertcat).
        lbw, ubw: lower / upper bounds on the decision vector.
        G: list of constraint expressions (concatenated with vertcat).
        J: scalar objective expression.
        lbg, ubg: lower / upper bounds on the constraints.

    Returns:
        (stateOL, u, slackLW, slackUP) as extracted by ``ctrlOLResponse_zone``
        from the solver's decision vector.

    Warns:
        RuntimeWarning: if the solver statistics report an unsuccessful solve.
        The last iterate is still returned in that case.
    """
    import warnings

    nlp = dict(f=J, g=vertcat(*G), x=vertcat(*w))
    S = nlpsol('S', 'ipopt', nlp)
    r = S(lbx=lbw, ubx=ubw, x0=initialGuess, lbg=lbg, ubg=ubg)

    # Solver statistics describe the most recent call, so read them after solving.
    stats = S.stats()
    if not stats.get('success', True):
        warnings.warn(
            "IPOPT did not report success (return_status=%r); "
            "using the last iterate." % (stats.get('return_status'),),
            RuntimeWarning, stacklevel=2)

    stateOL, u, slackLW, slackUP = ctrlOLResponse_zone(r['x'])
    return stateOL, u, slackLW, slackUP






